#! /usr/bin/env python
#coding=utf-8

import os
import sys
import sqlite3
import concurrent.futures
import threading
import traceback

from dylib_file import DylibFile
from dependency import Dependency
from tablebuilder import CallsTableBuilder, SymbolsTableBuilder, DepedenceiesTableBuilder, ModuleTableBuilder

sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "../"))
from macho import MachOFileMgr

#https://superfastpython.com/threadpoolexecutor-thread-local/

class ModuleMgr(MachOFileMgr):
    """Collect per-module Mach-O information into an SQLite database.

    The database file is ``archinfo.db`` under ``out_path`` (or
    ``get_root_path()``). Per-module work is fanned out over a thread pool
    (``threadpool_do_tasks``). Each worker thread opens its own connection via
    ``connectDB()``. Table/view/index DDL runs on the manager's own
    connection (``self._conn`` / ``self._cursor``).
    """

    def __init__(self, dylib_shared_cache=None, out_path=None):
        """Initialise the manager and open the main database connection.

        :param dylib_shared_cache: forwarded to ``MachOFileMgr.__init__``.
        :param out_path: directory that will hold ``archinfo.db``; when falsy
            the manager's ``get_root_path()`` is used instead.
        """
        super(ModuleMgr, self).__init__(dylib_shared_cache, DylibFile, Dependency)
        _out_path = out_path if out_path else self.get_root_path()
        self._db_file = os.path.join(_out_path, "archinfo.db")
        self._conn = sqlite3.connect(self._db_file)
        self._cursor = self._conn.cursor()

    def _disable_journal(self):
        """Turn off the rollback journal and fsync on the main connection.

        NOTE: in SQLite both ``journal_mode`` (for non-WAL modes) and
        ``synchronous`` are per-connection settings. They therefore do NOT
        apply to the worker connections opened by ``connectDB()``.
        """
        self._cursor.execute("PRAGMA journal_mode = OFF")
        self._cursor.execute("PRAGMA synchronous = OFF")

    @staticmethod
    def threadpool_initializer(local, mgr, type):
        """Set up one worker thread of ``threadpool_do_tasks``.

        Opens a dedicated SQLite connection for the calling thread and stores
        it on ``local`` as ``local.conn`` / ``local.cursor``. Also makes sure
        the ``log`` directory next to this file exists. Per-thread log files
        are currently disabled, so ``local.log_file`` is ``None``.

        :param local: ``threading.local`` shared by all tasks of one job.
        :param mgr: the ``ModuleMgr`` whose database the worker writes to.
        :param type: job label; only used by the (disabled) log-file name.
        """
        local.conn, local.cursor = mgr.connectDB()
        log_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "log")
        # All workers run this concurrently. exist_ok avoids the
        # exists()/makedirs() race, where a losing thread would raise
        # FileExistsError and break the whole pool.
        os.makedirs(log_path, exist_ok=True)
        # Per-thread log files are disabled; to re-enable:
        #   fileName = os.path.join(log_path, type+"-"+str(threading.get_ident())+".log")
        #   local.log_file = open(fileName, "w")
        local.log_file = None

    @staticmethod
    def __threadpool_task_wrapper(task, local, mod):
        """Run ``task(local, mod)`` on a worker thread and report success.

        On success the worker's connection is committed. On any exception the
        error and its traceback are printed, the module's uncommitted writes
        are rolled back, and ``False`` is returned instead of propagating, so
        one bad module does not abort the whole job.

        :return: ``True`` if the task (and the commit) succeeded, else ``False``.
        """
        try:
            task(local, mod)
            # Worker connections are never committed anywhere else, and
            # sqlite3 does not commit pending changes when a connection is
            # closed or collected. Persist this module's work now.
            # (commit() is a no-op if no transaction is open.)
            local.conn.commit()
            return True
        except Exception as e:
            print("========================")
            print(e)
            # print_exc() shows where the exception was raised;
            # print_stack() would only show this wrapper's own call stack.
            traceback.print_exc()
            # Don't let a half-written module be committed by a later task
            # on this same thread.
            try:
                local.conn.rollback()
            except sqlite3.Error as rollback_err:
                print("Rollback failed: %s" % rollback_err)
            return False

    def threadpool_do_tasks(self, task, type, max_workers=12):
        """Run ``task(local, mod)`` for every module of ``get_all()`` in parallel.

        :param task: callable taking ``(local, mod)``; ``local`` carries the
            worker's ``conn``, ``cursor`` and ``log_file``.
        :param type: job label printed in the summary line.
        :param max_workers: thread-pool size (defaults to the previous fixed 12).
        :return: number of modules whose task failed.
        """
        local = threading.local()
        files = self.get_all()
        failCnt = 0
        # The with-block guarantees shutdown(wait=True) even if result()
        # raises, e.g. BrokenThreadPool after an initializer failure.
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers,
                                                   initializer=ModuleMgr.threadpool_initializer,
                                                   initargs=(local, self, type)) as executor:
            futures = [executor.submit(ModuleMgr.__threadpool_task_wrapper, task, local, mod) for mod in files]
            for future in concurrent.futures.as_completed(futures):
                if not future.result():
                    failCnt += 1

        print("Do parallel job %s finished with %d failures." % (type, failCnt))
        return failCnt

    @staticmethod
    def __add_symbols_task(local, mod):
        """Worker task: insert ``mod``'s symbols via the thread's cursor.

        :return: ``(syms, undefines, mod)`` where the counts come from
            ``mod.addSymbols``.
        """
        if local.log_file:
            local.log_file.write("\n\n")
        syms, undefines = mod.addSymbols(local.cursor, local.log_file)
        if local.log_file:
            local.log_file.write("Add %d symbols %d undefines for %d:%s finished.\n" % (syms, undefines, mod["id"], mod["name"]))
        return syms, undefines, mod

    def add_all_symbols(self):
        """Create the symbol tables and fill them for every module in parallel.

        Afterwards it (re)creates the ``symbols_count`` view and the name
        indexes on ``symbols`` and ``undefines``.
        """
        DylibFile.createSymbolTables(self._cursor)
        self._disable_journal()

        self.threadpool_do_tasks(ModuleMgr.__add_symbols_task, "symbols")

        # create symbols_count view (joins the calls_count view that
        # add_all_calls() creates)
        self._cursor.execute("drop view if exists symbols_count")
        sqlcmd = "create view IF NOT EXISTS [symbols_count] as SELECT symbols.parent_id, symbols.name, symbols.addr, symbols.bits, symbols.name_type, symbols.section, symbols.id, calls_count.count from symbols LEFT OUTER join calls_count on symbols.id=calls_count.symbol_id"
        self._cursor.execute(sqlcmd)

        self._cursor.execute("DROP INDEX IF EXISTS index_symbols")
        self._cursor.execute("CREATE INDEX index_symbols ON symbols (name)")
        self._cursor.execute("DROP INDEX IF EXISTS index_undefines")
        self._cursor.execute("CREATE INDEX index_undefines ON undefines (name)")

    @staticmethod
    def __query_calls_task(local, mod):
        """Worker task: call ``mod.queryCalls`` with the thread's cursor."""
        mod.queryCalls(local.cursor)

    @staticmethod
    def __add_calls_task(local, mod):
        """Worker task: call ``mod.addCalls`` with the thread's cursor."""
        mod.addCalls(local.cursor)

    def add_all_calls(self):
        """Build the calls table in two parallel passes, then (re)create its views.

        The first pass runs ``queryCalls`` for every module, the second runs
        ``addCalls``. Views created afterwards: ``call_details``,
        ``calls_count`` and ``calls_count_details``.
        """
        builder = CallsTableBuilder(self._cursor)
        builder.createTable()

        self._disable_journal()

        self.threadpool_do_tasks(ModuleMgr.__query_calls_task, "query_calls")

        self.threadpool_do_tasks(ModuleMgr.__add_calls_task, "calls")

        self._cursor.execute("drop view if exists call_details")
        symbolCols = ["syms.%s as %s" % (c, c) for c in SymbolsTableBuilder(self._cursor).getColumns()]
        callsCols = ["calls.%s as %s" % (c, c) for c in builder.getColumns()]
        self._cursor.execute("create view [call_details] as select %s, callers.name as caller, %s from calls left outer join symbols syms on syms.id = calls.symbol_id left outer join modules callers on callers.id = calls.caller_id left outer join dependencies deps on deps.id=calls.dependence_id" % (", ".join(symbolCols), ", ".join(callsCols)))

        self._cursor.execute("drop view if exists calls_count")
        sqlcmd = "create view IF NOT EXISTS [calls_count] as select symbol_id, count(caller_id) as count from calls group by symbol_id"
        self._cursor.execute(sqlcmd)

        self._cursor.execute("drop view if exists calls_count_details")
        sqlcmd = "create view IF NOT EXISTS [calls_count_details] as select calls_count.symbol_id, calls_count.count, symbols.name from calls_count left outer join symbols on calls_count.symbol_id=symbols.id"
        self._cursor.execute(sqlcmd)

    @staticmethod
    def __add_deps_task(local, mod):
        """Worker task: call ``mod.addDeps`` with the thread's cursor."""
        mod.addDeps(local.cursor)

    def add_all_dependencies(self):
        """Build the dependencies table in parallel and (re)create the ``depends`` view.

        Any existing ``calledby`` view is dropped and not recreated.
        """
        builder = DepedenceiesTableBuilder(self._cursor)
        builder.createTable()

        self._disable_journal()

        self.threadpool_do_tasks(ModuleMgr.__add_deps_task, "deps")

        self._cursor.execute("drop view if exists depends")
        self._cursor.execute("drop view if exists calledby")

        dependenciesCols = ["dependencies.%s as %s " % (c, c) for c in builder.getColumns()]
        self._cursor.execute("create view [depends] as select %s, callers.name as caller, calledbys.name as callee from dependencies left outer join modules callers on callers.id = dependencies.caller_id left outer join modules calledbys on calledbys.id = dependencies.callee_id" % ", ".join(dependenciesCols))

    @staticmethod
    def __update_modules_info(local, mod):
        """Worker task: compute ``mod``'s size and counts info using the thread's cursor."""
        print("Update %d:%s module info now ..." % (mod["id"], mod["name"]))

        mod.extract_elf_size()
        mod.updateCountsInfo(local.cursor)

    def add_all_modules(self):
        """Create the modules table, update every module in parallel, then store them.

        The rows are written through the main connection by
        ``builder.addMultiObjToDb``. Those writes are committed by
        ``close_db()`` (unless the builder commits on its own).
        """
        builder = ModuleTableBuilder(self._cursor)
        builder.createTable()

        self._disable_journal()

        self.threadpool_do_tasks(ModuleMgr.__update_modules_info, "modules")

        builder.addMultiObjToDb(self.get_all())

        self._cursor.execute("drop view if exists binaries")
        self._cursor.execute("drop view if exists libraries")

        self._cursor.execute("DROP INDEX IF EXISTS index_modules")
        self._cursor.execute("CREATE INDEX index_modules ON modules (name)")

    def getCursor(self):
        """Return the cursor of the main connection."""
        return self._cursor

    def getDBFile(self):
        """Return the path of the SQLite database file."""
        return self._db_file

    def connectDB(self):
        """Open a new connection to the database file.

        sqlite3 connections may by default only be used by the thread that
        created them, which is why each worker thread calls this.

        :return: ``(conn, cursor)``.
        """
        # must set a big timeout value, otherwise concurrent db operation may fail
        conn = sqlite3.connect(self._db_file, timeout=20)
        cursor = conn.cursor()
        return conn, cursor

    def close_db(self):
        """Commit pending writes of the main connection and close it."""
        self._conn.commit()
        self._cursor.close()
        self._conn.close()

    def check_symbols_count(self):
        """Print every module whose symbol counts disagree with the database."""
        for mod in self.get_all():
            provided, undefined = mod.symbols()
            # Query each count once instead of re-querying for the report.
            db_provided = mod.get_provided_symbols_cnt(self._cursor)
            db_needed = mod.get_needed_symbols_cnt(self._cursor)
            if len(provided) != db_provided or len(undefined) != db_needed:
                print("Module %d:%s symbols information incorrect:" % (mod["id"], mod["name"]))
                print("Provided: %d, in db: %d" % (len(provided), db_provided))
                print("Needed: %d, in db: %d" % (len(undefined), db_needed))

if __name__ == '__main__':
    # Script entry point: scan all Mach-O files, then run the enabled
    # database-building stages.
    mgr = ModuleMgr()

    mgr.scan_all_files()

    # Optional stages. Note that the symbols_count view created by
    # add_all_symbols() joins the calls_count view created by add_all_calls().
    #mgr.add_all_symbols()

    #mgr.add_all_calls()

    #mgr.add_all_dependencies()

    mgr.add_all_modules()

    #mgr.check_symbols_count()

    # Commit whatever was written through the main connection (e.g. by
    # add_all_modules) and close the database; without this those pending
    # changes are discarded when the process exits.
    mgr.close_db()
